# -*- coding: utf-8 -*-
import os 
import warnings
import numpy as np
import pandas as pd
from sklearn import preprocessing
from sklearn.preprocessing import LabelBinarizer
from sklearn.cluster import DBSCAN
import pickle
import seaborn as sns
import matplotlib.pyplot as plt
from read_data import read_data_from_path
from read_data import plot_cluster
from read_data import plot_surface
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from collections import Counter

warnings.filterwarnings("ignore")   #不显示警告


def data_table(cdn_data, ener_arr, ener_plot=False):
    '''
    将 cdn_data 
    '''
    columns = []
    # 弄一个装逼的列名
    for i in range(20):
        for j in ('x','y','z'):
            columns.append(f'第{i+1}个原子的{j}坐标')
    
    cdn = pd.DataFrame(data=cdn_data, columns=columns)
    ener = pd.DataFrame(data=ener_arr,columns=['能量'])
    # data 是一个坐标数据+能量的表明
    data = pd.concat([cdn,ener],axis=1)

    # print(data.describe())    # 描述 data 的 mean、std 等统计信息
    if ener_plot == False:
        data.iloc[:,:-1].plot.kde()    # 画出 所有坐标的kde
    else:
        data.iloc[:,-1].plot.kde()    # 画出能量的 kde
    plt.show()
    # data.to_excel("data.xlsx", index=False)   #保存为 excel 表格
    return data

def select_params(cdn_for_cluster):
    '''
    用于筛选 DBSCAN 的参数
    '''
    # 参数表格：eps[5,5.1,5.2,...,6.0], n=[25,26,27,28]
    for eps in [i*0.01 for i in range(50,60)]:
        for n in range(25,29):
            db = DBSCAN(eps=eps, min_samples=n).fit(cdn_for_cluster)
            # 输出 DBSCAN 聚类结果
            labels = db.labels_
            
            # 输出聚类簇数，噪声点，用于筛选 eps、min_samples
            # Number of clusters in labels, ignoring noise if present.
            # 聚类簇数
            n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
            # 噪声点总数
            n_noise_ = list(labels).count(-1)
            print(f'eps={eps}, n={n}')
            print('Estimated number of clusters: %d' % n_clusters_)
            print('Estimated number of noise points: %d' % n_noise_)

def clustering(cdn_for_cluster, eps, n):
    '''
    聚类，真正聚类啦，上一个是为了筛选参数。
    '''
    db = DBSCAN(eps=eps, min_samples=n).fit(cdn_for_cluster)
    labels = db.labels_
    
    cdn_for_cluster['label'] = labels
    
    # Number of clusters in labels, ignoring noise if present.
    n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
    n_noise_ = list(labels).count(-1)
    print(f'eps={eps}, n={n}')
    print('Estimated number of clusters: %d' % n_clusters_)
    print('Estimated number of noise points: %d' % n_noise_)
    
    # 分别记录各聚类簇下，所属的 Au20 原子数（共19980个）
    c = Counter()
    for cluster_num in labels:
        c[cluster_num] += 1
    # 输出聚类簇下的 Au20 原子数
    print(c)
    return cdn_for_cluster, labels, db

def plot_result(cdn_for_cluster, labels, exclude=[]):
    '''
    画图，所有原子，并标记他们所属的聚类簇。
    '''
    fig = plt.figure(figsize=(4,4))
    ax = fig.add_subplot(111, projection='3d')
    # 给坐标数据搞上其所属的聚类簇
    cdn_for_cluster['label'] = labels
    # 将那些需要排除的聚类簇，那些所属的原子数排除掉，以便画图时不用画出来
    cdn_for_cluster = cdn_for_cluster.loc[~cdn_for_cluster['label'].isin(exclude)]
    
    im = ax.scatter(xs=cdn_for_cluster.iloc[:,0], 
                ys=cdn_for_cluster.iloc[:,1], 
                zs=cdn_for_cluster.iloc[:,2], 
                c=cdn_for_cluster.iloc[:,-1],
                cmap=plt.get_cmap('gist_ncar'),
                alpha=0.5)
    plt.colorbar(im)
    plt.show()
    

def data_clu_process(labels, file_num=999, moleclue_num=20):
    '''
    输出每个团簇，每个聚类簇下所有的AU20原子数，并作为其输入数据。
    '''
    # 根据聚类簇，弄一个列名。
    num = len(set(labels))-1
    columns = [i for i in range(-1,num)]   
    # 构造一个空的 dataframe
    data_after = pd.DataFrame(columns=columns)
    
    for i in range(file_num):
        # 构造一个团簇的计数字典
        cnt = dict((k,0) for k in columns)
        # 团簇有 20 个原子，所以要 20 个、20 个来呀。用在第二题时，得修改一下
        tmp = labels[i:i+20]
        for cluster_num in tmp:
            # 计数
            cnt[cluster_num] += 1
        # 将计数字典，转换为 dataframe 数据。
        data_after = data_after.append(cnt, ignore_index=True)
    
    return data_after

def standardlize(data, plot=True):
    '''
    将数据进行标准化，标准化后方差为 1、均值为0.
    '''
    scaler = preprocessing.StandardScaler().fit(data)
    print(scaler.mean_)
    print(scaler.scale_)
    # 将标准化后的数据弄成一个 dataframe。
    data_standard = scaler.transform(data)
    data_standard = pd.DataFrame(data=data_standard, columns=['能量'])
    if plot:
        data_standard.plot.kde()  # 画出标准化后，能量数据的 KDE
        plt.show()
        
    return scaler, data_standard

def divide(data):
    '''
    将能量数据，按照箱型图，区分成 5 个等级（进行离散化）
    '''
    # 计算三等分点、IQR、中位数 和上界、下界等。
    q1 = float(np.round(data.quantile(0.25).values, 2))
    q3 = float(np.round(data.quantile(0.75).values, 2))
    iqr = float(np.round(q3 - q1, 2))
    med = float(np.round(data.median().values))
    top_critical = med + 1.5*iqr
    bottom_critical = med - 1.5*iqr
    
    # 对数据进行离散化
    bins = [bottom_critical, q1, q3, top_critical]
    data = np.digitize(data.iloc[:,0], bins=bins, right=True)
    # 保存离散化后的变量
    data_div = pd.DataFrame(data=data, columns=['能量'])
    # 对离散化后的能量数据，进行 one-hot 编码
    one_hot = LabelBinarizer() 
    data_one_hot = one_hot.fit_transform(data_div) 
    data_one_hot = pd.DataFrame(data=data_div, columns=one_hot.classes_)
    
    return data_one_hot, data_div, one_hot

def dump_data(one_hot, ener_one_hot, ener_div, ener_arr, scaler,
                db, cdn_labels, data_after_clu):
    '''
    用pickle格式，保存所有的数据
    '''
    pickle.dump(one_hot, open(r'.\model_and_data\one_hot.pkl','wb'))
    pickle.dump(ener_one_hot, open(r'.\model_and_data\ener_one_hot.pkl','wb'))
    pickle.dump(ener_div, open(r'.\model_and_data\ener_div.pkl','wb'))
    pickle.dump(ener_arr, open(r'.\model_and_data\ener_arr.pkl','wb'))
    pickle.dump(scaler, open(r'.\model_and_data\scaler.pkl', 'wb'))
    pickle.dump(db, open(r'.\model_and_data\dbscan.pkl', 'wb'))
    pickle.dump(cdn_labels, open(r'.\model_and_data\cdn_labels.pkl', 'wb'))
    pickle.dump(data_after_clu, open(r'.\model_and_data\data_after_clu.pkl', 'wb'))


if __name__ == '__main__':
    
    # 改变工作路径
    path = '..\题目\附件\Au20_OPT_1000'
    os.chdir(path)

    # 构建一个临时变量
    au_num_arr = []
    ener_arr = []
    cdn_data = []

    # 构建一个空的 dataframe
    cdn_for_cluster = pd.DataFrame(columns=['x','y','z'])
#   data_table(cdn_data, ener_arr)  #绘制所有的 AU20 的所有坐标的 KDE
    # 用于筛选 DBSCAN 的参数
#   select_params(cdn_for_cluster)

    for file in os.listdir():
    # 遍历 AU20_OPT_1000 文件夹下所有的 .xyz 文件
        # 读取所有的 .xyz 文件
        au_num, energy, cdn_vec, cdn_mat = read_data_from_path(file)
        
        # 构建名字变量（没用）
        au_num_arr.append(au_num)
        # 构建能量向量
        ener_arr.append(energy)
        # 构建坐标向量的向量
        cdn_data.append(cdn_vec)
        
        # 临时矩阵...（20X3）
        cdn_mat = pd.DataFrame(data=cdn_mat, columns=['x','y','z'])
        # 构建坐标矩阵（19980X3)
        cdn_for_cluster = pd.concat([cdn_for_cluster, cdn_mat], axis=0)
    
    # 将工作路径改回
    path = '..\..\..\代码'
    os.chdir(path)    
        
    # 将 list 转为 np.array
    cdn_data = np.array(cdn_data)
    # 转为 2D 的 np.array
    ener_arr = np.array(ener_arr).reshape((-1,1))
    
    # 设置 DBSCAN 的聚类参数
    eps = 0.5
    n = 28
    # 进行 DBSCAN 聚类，其中 db 是 DBSCAN 算法模型；
    # cdn_labels 是坐标数据+每个原子对应的聚类簇
    # labels 是每个原子所属的 聚类簇。
    cdn_labels, labels, db = clustering(cdn_for_cluster, eps, n)
    
    # 画出所有原子，并标记所属的聚类簇
#    plot_result(cdn_for_cluster, labels)
#   排除标记为 -1 的原子后，所有的原子
#    plot_result(cdn_for_cluster, labels, exclude=[-1]) 
#   排除标记为 -1 和0 的原子后，所有的原子。
#   plot_result(cdn_for_cluster, labels, exclude=[-1,0])

#   将数据处理为，从坐标型，聚类所属原子数型（定性分析）
    data_after_clu = data_clu_process(labels)

    data_table(cdn_data,ener_arr, ener_plot=True) #绘制 KDE（能量）
#   记得把 plot 改为 True，若要画 KDE 的话
    scaler, ener_standard = standardlize(ener_arr, plot=False) 
    # 计算标准化后的均值方差
#    mean = ener_standard.mean()
#    std = ener_standard.std()
#    print(f'标准化后的均值,方差为：{mean},{std}')

#   画箱型图
#    ener_standard.boxplot()
#    plt.show()
#   对能量数据离散处理、并进行 one-hot 编码
    ener_one_hot, ener_div, one_hot = divide(ener_standard)

#   保存数据
    dump_data(one_hot, ener_one_hot, ener_div, ener_arr, scaler,
                db, cdn_labels, data_after_clu)
